package async

import (
	"context"
	"fmt"
	"sync"

	"golang.org/x/sync/errgroup"
)

type (
	// groupHandleFunc is a unit of work run by a Group; it yields a value of
	// type T or an error.
	groupHandleFunc[T any] func() (result T, err error)

	// GroupResult is the output of a task that finished without error,
	// tagged with the flag the task was submitted under.
	GroupResult[T any] struct {
		Flag string // flag passed to Add or Register
		Data T      // value returned by the task
	}

	// registerTask is a task queued by Register until Collect starts it.
	registerTask[T any] struct {
		flag   string
		handle groupHandleFunc[T]
	}
)

// Group runs tasks producing values of type T concurrently on an
// errgroup.Group and hands successful results back through a buffered
// channel. Construct it with NewGroup; tasks are submitted with Add
// (started immediately) or Register (started by Collect).
type Group[T any] struct {
	// current counts tasks submitted so far; count is the task limit given
	// to NewGroup, which is also the capacity of resultChan.
	current, count int

	// ctx is the context derived by errgroup.WithContext alongside group.
	ctx context.Context

	// mux is locked whenever current is incremented.
	mux sync.RWMutex

	// group runs the submitted tasks.
	group *errgroup.Group

	// registerTasks holds tasks queued by Register until Collect starts them.
	registerTasks []registerTask[T]

	// resultChan receives one entry per task that completes without error.
	resultChan chan *GroupResult[T]
}

// NewGroup returns a Group that accepts at most count tasks. Results are
// buffered in a channel whose capacity is count. The group's context is
// derived from context.Background via errgroup.WithContext.
//
// count must not be negative: make panics on a negative buffer size.
func NewGroup[T any](count int) *Group[T] {
	g := new(Group[T])
	g.count = count
	g.group, g.ctx = errgroup.WithContext(context.Background())
	g.resultChan = make(chan *GroupResult[T], count)
	return g
}

// addOne increments the number of tasks submitted to the group and returns
// the new total. The increment and the read happen under the same lock, so
// every caller observes its own unique position.
func (ag *Group[T]) addOne() int {
	ag.mux.Lock()
	defer ag.mux.Unlock()

	ag.current++
	return ag.current
}

// Register queues handle under flag without starting it. Queued tasks are
// handed to Add by Collect, in registration order.
func (ag *Group[T]) Register(handle groupHandleFunc[T], flag string) *Group[T] {
	ag.registerTasks = append(ag.registerTasks, registerTask[T]{flag: flag, handle: handle})
	return ag
}

// Add starts handle immediately in the group's errgroup.
//
// Each task is numbered when Add is called. A task whose number exceeds the
// limit passed to NewGroup returns an error instead of running, which cancels
// the group's context (errgroup.WithContext semantics). A task that finds the
// context already cancelled when it starts returns nil without running.
// A successful result is sent on the result channel tagged with flag; a
// handle error is reported through Collect.
func (ag *Group[T]) Add(handle groupHandleFunc[T], flag string) *Group[T] {
	// Capture this task's position now, under the lock. Reading ag.current
	// from inside the goroutine raced with later addOne calls, and it could
	// reject tasks that were within the limit once more tasks had been added.
	position := ag.addOne()

	ag.group.Go(func() error {
		select {
		case <-ag.ctx.Done():
			return nil
		default:
		}

		// resultChan is buffered to ag.count entries. Rejecting tasks past
		// the limit keeps the number of sends below within that capacity.
		if position > ag.count {
			return fmt.Errorf("Group limit: %d, current: %d", ag.count, position)
		}

		result, err := handle()
		if err != nil {
			return err
		}
		ag.resultChan <- &GroupResult[T]{Flag: flag, Data: result}
		return nil
	})
	return ag
}

// Collect starts every task queued with Register, waits for all tasks in the
// group (including those started directly with Add) to finish, and returns
// the channel holding their results.
//
// The result channel is closed before Collect returns, so callers can range
// over it to drain the buffered results. If any task failed, the error from
// the errgroup's Wait is returned together with a nil channel. Because the
// channel is closed unconditionally, calling Collect a second time panics.
func (ag *Group[T]) Collect() (<-chan *GroupResult[T], error) {
	defer close(ag.resultChan)

	for i := range ag.registerTasks {
		task := ag.registerTasks[i]
		ag.Add(task.handle, task.flag)
	}

	err := ag.group.Wait()
	if err != nil {
		return nil, err
	}
	return ag.resultChan, nil
}
